from lm import LM
from openai import OpenAI
import openai
import sys
import time
import os
import numpy as np
import logging
import concurrent.futures
from functools import partial
from dotenv import load_dotenv

# Upper bound on consecutive API failures a single request tolerates before
# the caller gives up and returns (None, None).
MAX_NUM_ERROR = 5

# Import-time, process-wide setup: pull variables from a .env file (if any)
# into os.environ so OPENAI_API_KEY is visible to load_model, and raise the
# interpreter's recursion ceiling.
load_dotenv()
sys.setrecursionlimit(10000)

def async_process(fn, inps, workers=10):
    """Apply ``fn`` to every item of ``inps`` concurrently on a thread pool.

    Args:
        fn: Callable taking a single input item.
        inps: Iterable of inputs; it is materialized into a list first.
        workers: Maximum number of worker threads. Values below 1 are clamped
            to 1 and values above the number of inputs are capped to it, so
            callers that size the pool from the input length (possibly 0) do
            not crash. ``None`` defers to the ``ThreadPoolExecutor`` default.

    Returns:
        A list of ``fn(item)`` results in the same order as ``inps``
        (``Executor.map`` preserves input order).

    Raises:
        Any exception raised by ``fn`` is re-raised while collecting results.
    """
    inps = list(inps)
    if not inps:
        # ThreadPoolExecutor(max_workers=0) raises ValueError; nothing to do anyway.
        return []
    if workers is not None:
        workers = max(1, min(workers, len(inps)))
    with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
        return list(executor.map(fn, inps))

class OpenAIModel(LM):
    """LM backend that queries the OpenAI Chat Completions API.

    ``model_name`` selects the request path rather than the API model:
    ``"ChatGPT"`` goes through :meth:`call_ChatGPT` (full message list, with
    optional few-shot turns); any other value goes through :meth:`call_GPT3`,
    which sends only the content of the last message. Both paths send
    :attr:`api_model_name` as the actual model id.
    """

    # Concrete model id sent to the API by both request paths. Previously
    # hard-coded inside each call_* method; override per subclass/instance.
    api_model_name = "gpt-4o-mini-2024-07-18"

    def __init__(self, model_name, cache_file=None, key_path="api.key"):
        """Store configuration, then run ``LM.__init__(cache_file)``.

        Args:
            model_name: ``"ChatGPT"`` for the chat / few-shot path; any other
                value selects the single-prompt path.
            cache_file: Passed straight to the ``LM`` base class.
            key_path: Stored for backward compatibility only; :meth:`load_model`
                reads the key from the ``OPENAI_API_KEY`` environment variable.
        """
        self.model_name = model_name
        self.key_path = key_path
        self.temp = 0.7  # sampling temperature used for every request
        self.save_interval = 100  # save_cache() runs when add_n is a multiple of this
        # Attributes are assigned before the base initializer runs —
        # NOTE(review): presumably because LM.__init__ calls load_model; confirm.
        super().__init__(cache_file)

    def load_model(self):
        """Create the OpenAI client using the ``OPENAI_API_KEY`` environment variable."""
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.model = self.model_name

    def _generate(self, prompt, max_sequence_length=2048, max_output_length=128, few_shot=False):
        """Generate completion(s) for ``prompt``.

        Args:
            prompt: A single prompt string, or a list of prompt strings that are
                sent concurrently (one user message each).
            max_sequence_length: Accepted for interface compatibility; unused.
            max_output_length: Forwarded as the completion token limit.
            few_shot: When true and ``model_name == "ChatGPT"``, a single string
                prompt is split into alternating user/assistant turns.

        Returns:
            A ``(text, response)`` tuple for a string prompt, or a list of such
            tuples (in prompt order) for a list prompt. A request that exhausts
            its retries yields ``(None, None)``.
        """
        # add_n / save_cache presumably come from the LM base class (caching).
        if self.add_n % self.save_interval == 0:
            self.save_cache()

        if self.model_name == "ChatGPT":
            call_fn = partial(self.call_ChatGPT, max_new_tokens=max_output_length)
        else:
            call_fn = partial(self.call_GPT3, max_tokens=max_output_length)

        if isinstance(prompt, list):
            messages = [[{"role": "user", "content": p}] for p in prompt]
            # ThreadPoolExecutor rejects max_workers=0, so an empty batch still
            # needs one worker (and then simply returns []).
            return async_process(call_fn, messages, workers=max(1, len(messages)))

        if few_shot and self.model_name == "ChatGPT":
            message = self._few_shot_messages(prompt)
        else:
            message = [{"role": "user", "content": prompt}]
        return call_fn(message)

    @staticmethod
    def _few_shot_messages(prompt):
        """Split a few-shot prompt into alternating user/assistant chat messages.

        Shots are separated by blank lines (``"\\n\\n"``). For each shot the
        first line becomes a user message and the remaining lines become the
        assistant reply; the final shot contributes only a user message.
        NOTE(review): lines after the first in the final shot are dropped —
        confirm that is intended.
        """
        shots = prompt.split('\n\n')
        message = []
        for i, shot in enumerate(shots):
            lines = shot.split('\n')
            message.append({'role': 'user', 'content': lines[0].strip()})
            if (i + 1) < len(shots):
                message.append({'role': 'assistant', 'content': '\n'.join(lines[1:]).strip()})
        return message

    def _chat_with_retry(self, messages, max_tokens, log_payload):
        """Send a chat completion request, retrying transient failures.

        Makes up to ``MAX_NUM_ERROR + 1`` attempts, sleeping ``2 ** k`` seconds
        after the k-th failure only when another attempt will follow (the
        original loop also slept after the final failure before giving up).

        Args:
            messages: Chat messages passed as ``messages=`` to the API.
            max_tokens: Completion token limit.
            log_payload: What to log if the API rejects the request.

        Returns:
            ``(text, response)`` on success, ``(None, None)`` once retries are
            exhausted.

        Raises:
            openai.BadRequestError: re-raised immediately (e.g. prompt too long);
                retrying an invalid request cannot succeed.
        """
        num_rate_errors = 0
        while True:
            try:
                response = self.client.chat.completions.create(
                    model=self.api_model_name,
                    messages=messages,
                    max_tokens=max_tokens,
                    temperature=self.temp,
                )
            except openai.BadRequestError:
                logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{log_payload}\n\n")
                raise
            except Exception as e:
                # Deliberately broad: any other failure is treated as transient
                # and, after the retry budget, reported as (None, None).
                num_rate_errors += 1
                if num_rate_errors > MAX_NUM_ERROR:
                    logging.error("API error: %s (%d). Giving up", e, num_rate_errors)
                    return None, None
                delay = 2 ** num_rate_errors
                logging.error("API error: %s (%d). Waiting %d sec", e, num_rate_errors, delay)
                time.sleep(delay)
            else:
                return response.choices[0].message.content, response

    def call_ChatGPT(self, message, max_len=1024, max_new_tokens=512):
        """Send the full ``message`` list as a chat completion.

        ``max_len`` is accepted for interface compatibility and is unused.
        Returns ``(text, response)`` or ``(None, None)`` after exhausting retries.
        """
        return self._chat_with_retry(message, max_new_tokens, log_payload=message)

    def call_GPT3(self, message, max_len=512, max_tokens=512, num_log_probs=0, echo=False, verbose=False):
        """Send only the last message's content as a single user turn.

        ``max_len``, ``num_log_probs``, ``echo`` and ``verbose`` are accepted for
        interface compatibility and are unused. Returns ``(text, response)`` or
        ``(None, None)`` after exhausting retries.
        """
        prompt = message[-1]['content']
        return self._chat_with_retry(
            [{"role": "user", "content": prompt}], max_tokens, log_payload=prompt
        )


# from lm import LM
# # import openai
# from openai import OpenAI
# import openai
# import sys
# import time
# import os
# import numpy as np
# import logging
# import concurrent.futures
# from functools import partial

# from dotenv import load_dotenv
# load_dotenv()

# sys.setrecursionlimit(10000)  # Increase the recursion limit

# MAX_NUM_ERROR = 5

# def async_process(fn,inps,workers=10):
#     with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
#         out = list(executor.map(fn,inps))
#     return out

# class OpenAIModel(LM):

#     def __init__(self, model_name, cache_file=None, key_path="api.key"):
#         self.model_name = model_name
#         self.key_path = key_path
#         self.temp = 0.7
#         self.save_interval = 100
#         super().__init__(cache_file)

#     def load_model(self):
#         # load api key
#         # key_path = self.key_path
#         # assert os.path.exists(key_path), f"Please place your OpenAI APT Key in {key_path}."
#         # with open(key_path, 'r') as f:
#         #     api_key = f.readline()
#         # self.client = openai.OpenAI(api_key = api_key.strip())
#         self.client = OpenAI(api_key = os.getenv("OPENAI_API_KEY"))
#         self.model = self.model_name

#     def _generate(self, prompt, max_sequence_length=2048, max_output_length=128,few_shot=False):
#         if self.add_n % self.save_interval == 0:
#             self.save_cache()
#         # return a tuple of string (generated text) and metadata (any format)
#         # This should be about generating a response from the prompt, no matter what the application is
#         if self.model_name == "ChatGPT":
#             call_fn = partial(self.call_ChatGPT,max_new_tokens = max_output_length)
#         else:
#             call_fn = partial(self.call_GPT3,max_tokens = max_output_length)
#         if isinstance(prompt,list):
#             is_list = True
#             message = [[{"role": "user", "content": p}] for p in prompt]
#         else:
#             is_list = False
#             if few_shot and self.model_name == 'ChatGPT':
#                 message = []
#                 split_shots = prompt.split('\n\n')
#                 for i,shot in enumerate(split_shots):
#                     split_s = shot.split('\n')
#                     message.append({'role':'user','content':split_s[0].strip()})
#                     if (i+1) < len(split_shots):
#                         message.append({'role':'assistant','content':'\n'.join(split_s[1:]).strip()})
#             else:
#                 message = [{"role": "user", "content": prompt}]
#         if is_list:
#             out = async_process(call_fn,message,workers = len(message))
#         else:
#             out = call_fn(message)
#         return out 

#     def call_ChatGPT(self,message, max_len=1024,max_new_tokens=512):
#         # call GPT-3 API until result is provided and then return it
#         # model_name = "gpt-3.5-turbo-0125"
#         model_name = "gpt-4o-mini-2024-07-18"
#         # model_name = "gpt-4-turbo"
#         response = None
#         received = False
#         num_rate_errors = 0
#         while not received:
#             try:
#                 response = self.client.chat.completions.create(model=model_name,
#                 # response = self.client.ChatCompletion.create(model=model_name,
#                                                     messages=message,
#                                                     max_tokens = max_new_tokens,
#                                                     temperature=self.temp)
#                 received = True
#             except:
#                 num_rate_errors += 1
#                 error = sys.exc_info()[0]
#                 if error == openai.BadRequestError:
#                     # something is wrong: e.g. prompt too long
#                     logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{message}\n\n")
#                     assert False
#                 logging.error("API error: %s (%d). Waiting %dsec" % (error, num_rate_errors, np.power(2, num_rate_errors)))
#                 time.sleep(np.power(2, num_rate_errors))
#             if num_rate_errors > MAX_NUM_ERROR:
#                 return None,None
#         return response.choices[0].message.content,response


#     def call_GPT3(self,message,max_len=512, max_tokens = 512, num_log_probs=0, echo=False, verbose=False):
#         # call GPT-3 API until result is provided and then return it
#         # model_name="gpt-3.5-turbo-0125"
#         model_name = "gpt-4o-mini-2024-07-18"
#         # model_name = "gpt-4-turbo"

#         response = None
#         received = False
#         num_rate_errors = 0
#         prompt = message[-1]['content']
#         while not received:
#             try:
#                 response = self.client.chat.completions.create(model=model_name,
#                 # response = self.client.ChatCompletion.create(model=model_name,
#                                                     prompt=prompt,
#                                                     max_tokens = max_tokens,
#                                                     temperature=self.temp,
#                                                     logprobs=num_log_probs,
#                                                     echo=echo)
#                 received = True
#             except:
#                 error = sys.exc_info()[0]
#                 num_rate_errors += 1
#                 if error == openai.BadRequestError:
#                     # something is wrong: e.g. prompt too long
#                     logging.critical(f"InvalidRequestError\nPrompt passed in:\n\n{prompt}\n\n")
#                     assert False
#                 logging.error("API error: %s (%d)" % (error, num_rate_errors))
#                 time.sleep(np.power(2, num_rate_errors))
#             if num_rate_errors > MAX_NUM_ERROR:
#                 return None,None
#         return response.choices[0].text,response

# File: test_openai_dotenv.py

# import openai
# from dotenv import load_dotenv
# import os

# # Load environment variables from .env file
# load_dotenv()

# # Get your API key from the environment
# api_key = os.getenv("OPENAI_API_KEY")

# if not api_key:
#     raise ValueError("❌ OPENAI_API_KEY not found in .env file!")

# openai.api_key = api_key

# try:
#     response = openai.ChatCompletion.create(
#         model="gpt-4o-mini-2024-07-18",  # You can try "gpt-3.5-turbo" or "gpt-4-turbo" too
#         messages=[
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Respond with a 3-word sentence."}
#         ]
#     )
#     print("✅ Success!")
#     print("Response:", response['choices'][0]['message']['content'])

# except openai.error.OpenAIError as e:
#     print("❌ OpenAI API Error:", e)

# except Exception as e:
#     print("❌ Other Error:", e)

# from lm import LM
# from openai import OpenAI
# import openai
# import sys
# import time
# import os
# import numpy as np
# import logging
# import concurrent.futures
# from functools import partial
# from dotenv import load_dotenv

# load_dotenv()
# sys.setrecursionlimit(10000)

# MAX_NUM_ERROR = 5
# DEFAULT_MODEL_NAME = "gpt-4o-mini-2024-07-18"

# def async_process(fn, inps, workers=10):
#     with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
#         return list(executor.map(fn, inps))

# class OpenAIModel(LM):
#     def __init__(self, model_name=DEFAULT_MODEL_NAME, cache_file=None):
#         self.model_name = model_name
#         self.temp = 0.7
#         self.save_interval = 100
#         super().__init__(cache_file)

#     def load_model(self):
#         self.model = self.model_name  # No client instantiation here

#     def _generate(self, prompt, max_sequence_length=2048, max_output_length=128, few_shot=False):
#         if self.add_n % self.save_interval == 0:
#             self.save_cache()

#         call_fn = partial(
#             self.call_ChatGPT if self.model_name == "ChatGPT" else self.call_GPT3,
#             max_new_tokens=max_output_length
#         )

#         if isinstance(prompt, list):
#             message = [[{"role": "user", "content": p}] for p in prompt]
#             return async_process(call_fn, message, workers=len(message))
#         else:
#             if few_shot and self.model_name == "ChatGPT":
#                 message = []
#                 split_shots = prompt.split('\n\n')
#                 for i, shot in enumerate(split_shots):
#                     split_s = shot.split('\n')
#                     message.append({'role': 'user', 'content': split_s[0].strip()})
#                     if (i + 1) < len(split_shots):
#                         message.append({'role': 'assistant', 'content': '\n'.join(split_s[1:]).strip()})
#             else:
#                 message = [{"role": "user", "content": prompt}]
#             return call_fn(message)

#     def call_ChatGPT(self, message, max_len=1024, max_new_tokens=512):
#         api_key = os.getenv("OPENAI_API_KEY")
#         client = OpenAI(api_key=api_key)

#         num_rate_errors = 0
#         while True:
#             try:
#                 response = client.chat.completions.create(
#                     model=self.model_name,
#                     messages=message,
#                     max_tokens=max_new_tokens,
#                     temperature=self.temp
#                 )
#                 return response.choices[0].message.content, response
#             except Exception as e:
#                 num_rate_errors += 1
#                 if isinstance(e, openai.BadRequestError):
#                     logging.critical(f"InvalidRequestError\nPrompt:\n{message}")
#                     raise
#                 logging.error(f"API error: {e} ({num_rate_errors}). Waiting {2 ** num_rate_errors}s")
#                 time.sleep(2 ** num_rate_errors)
#                 if num_rate_errors > MAX_NUM_ERROR:
#                     return None, None

#     def call_GPT3(self, message, max_len=512, max_tokens=512, num_log_probs=0, echo=False):
#         api_key = os.getenv("OPENAI_API_KEY")
#         client = OpenAI(api_key=api_key)
#         prompt = message[-1]['content']

#         num_rate_errors = 0
#         while True:
#             try:
#                 response = client.completions.create(
#                     model=self.model_name,
#                     prompt=prompt,
#                     max_tokens=max_tokens,
#                     temperature=self.temp,
#                     logprobs=num_log_probs,
#                     echo=echo
#                 )
#                 return response.choices[0].text, response
#             except Exception as e:
#                 num_rate_errors += 1
#                 if isinstance(e, openai.BadRequestError):
#                     logging.critical(f"InvalidRequestError\nPrompt:\n{prompt}")
#                     raise
#                 logging.error(f"API error: {e} ({num_rate_errors}). Waiting {2 ** num_rate_errors}s")
#                 time.sleep(2 ** num_rate_errors)
#                 if num_rate_errors > MAX_NUM_ERROR:
#                     return None, None

